from functools import wraps
from itertools import chain
from threading import get_ident
from collections import defaultdict
from typing import Callable, Iterable
import datetime
import json
import re

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Query as _Query
from mistune import HTMLRenderer, Markdown

from zvms.res import *

# Shared Markdown -> HTML converter used by render_markdown().
markdown = Markdown(HTMLRenderer())
# Matches an <a ...>...</a> element (non-greedy; re.S lets it span newlines)
# and captures the link's inner content. render_markdown() uses it to replace
# every link with a bare <a>...</a>, dropping its attributes (e.g. href).
rule_remove_links = re.compile(r'<a.*?>(.*?)</a>', re.S)

def render_markdown(md):
    """Render Markdown text *md* to HTML with link attributes stripped.

    Every ``<a ...>content</a>`` in the rendered output is rewritten to a
    bare ``<a>content</a>``, so links keep their text but lose their target.
    """
    html = markdown.parse(md)
    return rule_remove_links.sub(r'<a>\1</a>', html)

# Session cache: engine -> defaultdict mapping a thread id (threading.get_ident())
# to that thread's Session. Populated lazily by get_session(); commit(),
# rollback() and close() iterate over its engines.
_sessions = {}

def get_session(engine):
    """Return the calling thread's session for *engine*, creating it on first use.

    Sessions are cached in ``_sessions``, first by engine and then by thread
    id (``threading.get_ident``). Each is built by a ``sessionmaker`` bound
    to *engine*.
    """
    per_thread = _sessions.get(engine)
    if per_thread is None:
        per_thread = defaultdict(sessionmaker(engine))
        _sessions[engine] = per_thread
    return per_thread[get_ident()]

def commit():
    """Commit the calling thread's session on every engine in ``_sessions``.

    ``get_session`` creates a session for this thread on any engine that
    does not have one yet.
    """
    for session in map(get_session, _sessions):
        session.commit()

def rollback():
    """Roll back the calling thread's session on every engine in ``_sessions``."""
    for engine in _sessions:
        session = get_session(engine)
        session.rollback()

def close():
    """Close the calling thread's session on every engine in ``_sessions``.

    The sessions stay cached in ``_sessions``; only ``Session.close`` is called.
    """
    for engine in _sessions:
        session = get_session(engine)
        session.close()

class _QueryProperty:
    """Descriptor that hands out a new ``Query`` wrapper on every access.

    The underlying SQLAlchemy query selects the owning class and runs on the
    calling thread's session for that class's ``__engine__``.
    """

    def __get__(self, obj, cls):
        session = get_session(cls.__engine__)
        return Query(session.query(cls))

def foo(func):
    """Let *func* be applied to one object or to every item of an iterable.

    If the first argument is iterable (e.g. a ``Query`` or a list of models),
    *func* is called on each item with the remaining arguments. The results
    are returned as a list. Otherwise *func* is simply called on the object.

    The per-item calls are made eagerly. A lazy generator here would make
    side-effecting helpers such as ``update`` silently do nothing unless the
    caller happened to consume the returned value.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if isinstance(self, Iterable):
            return [func(item, *args, **kwargs) for item in self]
        return func(self, *args, **kwargs)
    return wrapper

@foo
def select(self, *cols, **aliases):
    """Collect chosen attributes of this object into a dict.

    Each name in *cols* is copied under its own name. Each ``attr=key`` pair
    in *aliases* copies attribute ``attr`` under the output key ``key``.
    Attributes are read in order, *cols* first. A repeated output key keeps
    its first position but takes the last value. Through ``@foo``, calling
    this on an iterable applies it to each item.
    """
    picked = {}
    for col in cols:
        picked[col] = getattr(self, col)
    for attr, key in aliases.items():
        picked[key] = getattr(self, attr)
    return picked

@foo
def update(self, **updates):
    """Assign each keyword in *updates* onto this object, then call ``on_update``.

    A callable value is treated as a transform. It is called with the
    attribute's current value and its result is stored (see ``incr``).
    Through ``@foo``, calling this on an iterable applies it to each item.
    """
    for attr, value in updates.items():
        new_value = value(getattr(self, attr)) if isinstance(value, Callable) else value
        setattr(self, attr, new_value)
    self.on_update()

@foo
def insert(self):
    """Add this object to its engine's session, flush, then call ``on_insert``.

    Returns the object itself so construction and insertion can be chained.
    """
    session = get_session(self.__engine__)
    session.add(self)
    session.flush()
    self.on_insert()
    return self

def incr(amount):
    """Return a transform that adds *amount* to its argument (for ``update``)."""
    def _add(value):
        return value + amount
    return _add
    
def select_value(self, col):
    """Lazily yield attribute *col* of each item in *self*."""
    return (getattr(item, col) for item in self)

class Query:
    """Thin wrapper around a SQLAlchemy ``Query`` adding ZVMS helpers.

    Attributes not defined here are forwarded to the wrapped query. Any
    forwarded value or call result that is a SQLAlchemy ``Query`` is
    re-wrapped, so chains like ``Model.query.filter(...).order_by(...)``
    keep these helpers.
    """

    select = select
    update = update
    insert = insert
    select_value = select_value

    def __init__(self, query):
        self.__query = query

    def delete(self):
        """Call ``on_delete`` on every matched item, then bulk-delete via the wrapped query."""
        for item in self:
            item.on_delete()
        self.__query.delete()

    def get_or_error(self, ident, message='未查询到相关信息'):
        """Return the row with primary key *ident*.

        Raises:
            ZvmsError: With *message*, if no such row exists.
        """
        ret = self.__query.get(ident)
        if ret is None:
            raise ZvmsError(message)
        return ret

    def first_or_error(self, message='未查询到相关信息'):
        """Return the first result of the query.

        Raises:
            ZvmsError: With *message*, if the query matched nothing.
        """
        ret = self.__query.first()
        if ret is None:
            raise ZvmsError(message)
        return ret

    def __iter__(self):
        return iter(self.__query)

    def __getattr__(self, name):
        """Forward unknown attribute *name* to the wrapped SQLAlchemy query."""
        # Read the wrapped query from the instance dict directly. If __init__
        # never ran (e.g. copy/pickle reconstruction), ``self.__query`` would
        # re-enter __getattr__ and recurse until RecursionError.
        try:
            query = self.__dict__['_Query__query']
        except KeyError:
            raise AttributeError(name) from None
        attr = getattr(query, name)
        if isinstance(attr, _Query):
            return Query(attr)
        if callable(attr):
            return Query.__deco(attr)
        # Plain values (e.g. ``session``, ``statement``) are returned as-is
        # rather than being disguised as a wrapper function.
        return attr

    @staticmethod
    def __deco(func):
        """Wrap *func* so a SQLAlchemy ``Query`` result comes back as ``Query``."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            ret = func(*args, **kwargs)
            return Query(ret) if isinstance(ret, _Query) else ret
        return wrapper

class ModelMixIn:
    """Mix-in for models adding the ``select``/``update``/``insert`` helpers.

    Subclasses must define ``__engine__``. ``query`` and ``insert`` use it to
    pick the session. The ``on_*`` hooks do nothing here and may be overridden.
    """

    query = _QueryProperty()

    select = select
    update = update
    insert = insert

    def on_insert(self):
        """Hook called by ``insert`` after the object is flushed; no-op by default."""

    def on_update(self):
        """Hook called by ``update`` after new values are assigned; no-op by default."""

    def on_delete(self):
        """Hook called by ``Query.delete`` before rows are deleted; no-op by default."""

def success(message, **kwresult):
    """Commit pending work and return a JSON ``SUCCESS`` response.

    Args:
        message: Text stored under ``"message"``.
        **kwresult: Extra top-level fields merged into the payload. A
            ``type`` key here would override ``"SUCCESS"``.

    Returns:
        The payload serialized with ``json.dumps``.
    """
    ret = {'type': 'SUCCESS', 'message': message} | kwresult
    try:
        commit()
    finally:
        # Release this thread's sessions even when the commit itself fails.
        close()
    return json.dumps(ret)

def error(message):
    """Roll back pending work and return a JSON ``ERROR`` response carrying *message*."""
    try:
        rollback()
    finally:
        # Release this thread's sessions even when the rollback itself fails.
        close()
    return json.dumps({'type': 'ERROR', 'message': message})

# Pre-serialized "record not found" error payload (a JSON string).
# NOTE: error() runs once here, at import time, so its rollback()/close()
# side effects happen during module import, not when this value is later used.
not_found = error('未查询到相关记录')

class ZvmsError(Exception):
    """Application error carrying a human-readable *message*."""

    def __init__(self, message):
        # Initialise the base Exception so args/str()/repr() are set up the
        # standard way, not just as a side effect of Exception.__new__.
        super().__init__(message)
        self.message = message

def try_parse_time(str):
    """Parse a ``'%Y-%m-%d %H:%M:%S'`` timestamp into a naive ``datetime``.

    Args:
        str: The timestamp text. The name shadows the builtin but is kept for
            keyword-argument compatibility.

    Raises:
        ZvmsError: If the value does not match the format, or is not a string
            at all (e.g. ``None``), which ``strptime`` reports as TypeError.
    """
    try:
        return datetime.datetime.strptime(str, '%Y-%m-%d %H:%M:%S')
    except (ValueError, TypeError) as err:
        raise ZvmsError('请求接口错误: 非法的时间字符串') from err

def count(seq, predicate):
    """Return how many items of *seq* satisfy *predicate*.

    *predicate* is called once per item, in order; truthy results are counted.
    """
    return sum(1 for item in seq if predicate(item))

def exists(seq, predicate):
    """Return True if any item of *seq* satisfies *predicate*, else False.

    Stops at the first match, so *predicate* is not necessarily called on
    every item.
    """
    return any(predicate(item) for item in seq)

# Year value that classIdToString maps to grade "一"; thisYear - 1 maps to "二"
# and thisYear - 2 to "三". Other years get no grade character.
# NOTE(review): hard-coded; presumably needs bumping every school year — confirm.
thisYear = 2022

def classIdToString(a, this_year=None):
    """Convert a numeric class/identity id into a human-readable label.

    The id is read as ``year * 100 + class_no``:

    * Special identities are recognised first. Ids 10xxxx give 教师 (teacher),
      11xxxx give 管理员 (admin), 12xxxx give 系统 (system) and 13xxxx give
      超管 (super admin).
    * ``class_no`` 1-9 are 高 classes numbered 1-9, and 11-17 are 蛟 classes
      numbered 1-7.
    * The grade character is 一 when ``year == this_year``, 二 one year earlier
      and 三 two years earlier. It is omitted for any other year.
    * ``class_no`` 0, 10 and anything above 17 yield "无班级". Values above 17
      previously raised IndexError.

    Args:
        a: The id, as an int or anything ``int()`` accepts (e.g. "202201").
        this_year: Year mapped to grade 一; defaults to the module's ``thisYear``.

    Returns:
        The label, e.g. "高一1班", "蛟二3班", "教师" or "无班级".
    """
    class_id = int(a)
    year, class_no = divmod(class_id, 100)

    # Special identities (these should be documented), e.g. teacher
    # 100001/100002, admin 110001/110002, system 120003/120004, super admin 130001.
    special = {10: "教师", 11: "管理员", 12: "系统", 13: "超管"}.get(year // 100)
    if special is not None:
        return special

    # 0 and 10 are unused slots, and no class exists beyond 17.
    if class_no in (0, 10) or class_no > 17:
        return "无班级"

    if this_year is None:
        this_year = thisYear
    campus = "高" if class_no <= 9 else "蛟"
    grade = {this_year: "一", this_year - 1: "二", this_year - 2: "三"}.get(year, "")
    # 蛟 classes 11-17 are displayed as 1-7 (per the original author's
    # recollection of the school's numbering — confirm).
    number = class_no if class_no <= 9 else class_no - 10
    return f"{campus}{grade}{number}班"